{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "18AF5Ab4p6VL"
   },
   "source": [
    "# Training a simple neural network, with PyTorch data loading\n",
    "\n",
    "<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
    "\n",
    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
    "\n",
    "**Copyright 2018 The JAX Authors.**\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "https://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B_XlLLpcWjkA"
   },
   "source": [
    "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n",
    "\n",
    "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
    "\n",
    "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "OksHydJDtbbI"
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from jax import grad, jit, vmap\n",
    "from jax import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MTVcKi-ZYB3R"
   },
   "source": [
    "## Hyperparameters\n",
    "Let's get a few bookkeeping items out of the way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "-fmWA06xYE7d"
   },
   "outputs": [],
   "source": [
    "# A helper function to randomly initialize weights and biases\n",
    "# for a dense neural network layer\n",
    "def random_layer_params(m, n, key, scale=1e-2):\n",
    "  w_key, b_key = random.split(key)\n",
    "  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n",
    "\n",
    "# Initialize all layers for a fully-connected neural network with sizes \"sizes\"\n",
    "def init_network_params(sizes, key):\n",
    "  keys = random.split(key, len(sizes))\n",
    "  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n",
    "\n",
    "layer_sizes = [784, 512, 512, 10]\n",
    "step_size = 0.01\n",
    "num_epochs = 8\n",
    "batch_size = 128\n",
    "n_targets = 10\n",
    "params = init_network_params(layer_sizes, random.key(0))"
   ]
  },
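  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, here is a sketch that prints the shape of each layer's parameters. Each weight matrix has shape `(n, m)`, i.e. (outputs, inputs). That is why the prediction function below computes `jnp.dot(w, activations)`. The two helpers are copied from the cell above so that the snippet runs on its own. `example_params` is a hypothetical name used only here.\n",
    "\n",
    "```python\n",
    "from jax import random\n",
    "\n",
    "# Copies of the helpers above: weights have shape (n, m) = (outputs, inputs)\n",
    "def random_layer_params(m, n, key, scale=1e-2):\n",
    "  w_key, b_key = random.split(key)\n",
    "  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n",
    "\n",
    "def init_network_params(sizes, key):\n",
    "  keys = random.split(key, len(sizes))\n",
    "  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n",
    "\n",
    "example_params = init_network_params([784, 512, 512, 10], random.key(0))\n",
    "for w, b in example_params:\n",
    "  print(w.shape, b.shape)\n",
    "\n",
    "# Total number of trainable scalars in the network\n",
    "num_params = sum(w.size + b.size for w, b in example_params)\n",
    "print(num_params)\n",
    "```\n",
    "\n",
    "This should print `(512, 784) (512,)`, `(512, 512) (512,)` and `(10, 512) (10,)`, followed by a total of 669,706 parameters."
   ]
  },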
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BtoNk_yxWtIw"
   },
   "source": [
    "## Auto-batching predictions\n",
    "\n",
    "Let us first define our prediction function. Note that we're defining this for a _single_ image example. We're going to use JAX's `vmap` function to automatically handle mini-batches, with no performance penalty."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "7APc6tD7TiuZ"
   },
   "outputs": [],
   "source": [
    "from jax.scipy.special import logsumexp\n",
    "\n",
    "def relu(x):\n",
    "  return jnp.maximum(0, x)\n",
    "\n",
    "def predict(params, image):\n",
    "  # per-example predictions\n",
    "  activations = image\n",
    "  for w, b in params[:-1]:\n",
    "    outputs = jnp.dot(w, activations) + b\n",
    "    activations = relu(outputs)\n",
    "\n",
    "  final_w, final_b = params[-1]\n",
    "  logits = jnp.dot(final_w, activations) + final_b\n",
    "  return logits - logsumexp(logits)"
   ]
  },
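  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last line of `predict` subtracts `logsumexp(logits)` from the logits. This is the log-softmax, which turns raw scores into log-probabilities in a numerically stable way. The small sketch below uses made-up logit values. It checks that the result exponentiates to a distribution summing to 1 and that it matches `jax.nn.log_softmax`.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "from jax import nn\n",
    "from jax.scipy.special import logsumexp\n",
    "\n",
    "# Made-up logits for a 3-class problem\n",
    "logits = jnp.array([2.0, 1.0, 0.1])\n",
    "log_probs = logits - logsumexp(logits)\n",
    "\n",
    "# exp(log_probs) is a probability distribution, so it sums to 1\n",
    "total = jnp.exp(log_probs).sum()\n",
    "# The same quantity computed with JAX's built-in log-softmax\n",
    "matches_builtin = bool(jnp.allclose(log_probs, nn.log_softmax(logits)))\n",
    "print(total, matches_builtin)\n",
    "```"
   ]
  },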
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dRW_TvCTWgaP"
   },
   "source": [
    "Let's check that our prediction function only works on single images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "4sW2A5mnXHc5",
    "outputId": "9d3b29e8-fab3-4ecb-9f63-bc8c092f9006"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10,)\n"
     ]
    }
   ],
   "source": [
    "# This works on single examples\n",
    "random_flattened_image = random.normal(random.key(1), (28 * 28,))\n",
    "preds = predict(params, random_flattened_image)\n",
    "print(preds.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "PpyQxuedXfhp",
    "outputId": "d5d20211-b6da-44e9-f71e-946f2a9d0fc4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Invalid shapes!\n"
     ]
    }
   ],
   "source": [
    "# Doesn't work with a batch\n",
    "random_flattened_images = random.normal(random.key(1), (10, 28 * 28))\n",
    "try:\n",
    "  preds = predict(params, random_flattened_images)\n",
    "except TypeError:\n",
    "  print('Invalid shapes!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "oJOOncKMXbwK",
    "outputId": "31285fab-7667-4871-fcba-28e86adc3fc6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 10)\n"
     ]
    }
   ],
   "source": [
    "# Let's upgrade it to handle batches using `vmap`\n",
    "\n",
    "# Make a batched version of the `predict` function\n",
    "batched_predict = vmap(predict, in_axes=(None, 0))\n",
    "\n",
    "# `batched_predict` has the same call signature as `predict`\n",
    "batched_preds = batched_predict(params, random_flattened_images)\n",
    "print(batched_preds.shape)"
   ]
  },
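  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what `in_axes=(None, 0)` means, here is a toy sketch. The function `affine` is hypothetical and is not part of the tutorial. `None` passes the first argument unchanged to every example, while `0` maps over the leading axis of the second argument. The result matches an explicit Python loop over the batch, but we never have to write that loop ourselves.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "from jax import random, vmap\n",
    "\n",
    "def affine(w, x):\n",
    "  # Per-example function: w has shape (3, 4), x has shape (4,)\n",
    "  return jnp.dot(w, x)\n",
    "\n",
    "w = random.normal(random.key(0), (3, 4))\n",
    "xs = random.normal(random.key(1), (5, 4))  # a batch of 5 examples\n",
    "\n",
    "# Share `w` across the batch and map over axis 0 of `xs`\n",
    "batched = vmap(affine, in_axes=(None, 0))(w, xs)\n",
    "# The same computation written as an explicit Python loop\n",
    "looped = jnp.stack([affine(w, x) for x in xs])\n",
    "print(batched.shape, jnp.allclose(batched, looped, atol=1e-5))\n",
    "```"
   ]
  },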
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "elsG6nX03BvW"
   },
   "source": [
    "At this point, we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of `predict`, which we can use in a loss function. We can use `grad` to take the derivative of the loss with respect to the neural network parameters, and `jit` to speed everything up."
   ]
  },
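  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One detail is worth making concrete. `grad` differentiates with respect to the first argument and returns gradients with the same nested structure as that argument. This is what lets `update` zip `params` with `grads`. Here is a minimal sketch with a hypothetical one-layer `toy_loss`, wrapped in `jit` to show that the transformations compose.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "from jax import grad, jit\n",
    "\n",
    "# Hypothetical parameters with the same list-of-(w, b) layout as `params`\n",
    "toy_params = [(jnp.ones((2, 3)), jnp.zeros(2))]\n",
    "x = jnp.array([1.0, 2.0, 3.0])\n",
    "\n",
    "def toy_loss(params, x):\n",
    "  w, b = params[0]\n",
    "  return jnp.sum(jnp.dot(w, x) + b)\n",
    "\n",
    "# Transformations compose: compile the gradient function with jit\n",
    "grads = jit(grad(toy_loss))(toy_params, x)\n",
    "dw, db = grads[0]\n",
    "print(dw)  # each row equals x: d sum(w @ x + b) / d w[i, j] = x[j]\n",
    "print(db)  # all ones\n",
    "```"
   ]
  },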
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NwDuFqc9X7ER"
   },
   "source": [
    "## Utility and loss functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "6lTI6I4lWdh5"
   },
   "outputs": [],
   "source": [
    "def one_hot(x, k, dtype=jnp.float32):\n",
    "  \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
    "  return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
    "\n",
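    "def accuracy(params, images, targets):\n",
    "  \"\"\"Fraction of examples whose predicted class matches the one-hot target.\"\"\"\n",
    "  target_class = jnp.argmax(targets, axis=1)\n",
    "  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",
    "  return jnp.mean(predicted_class == target_class)\n",
    "\n",
    "def loss(params, images, targets):\n",
    "  # `preds` are log-probabilities, so this is the cross-entropy,\n",
    "  # averaged over both the batch axis and the class axis.\n",
    "  preds = batched_predict(params, images)\n",
    "  return -jnp.mean(preds * targets)\n",
    "\n",
    "@jit\n",
    "def update(params, x, y):\n",
    "  # One step of gradient descent on a mini-batch; `grads` has the same\n",
    "  # list-of-(w, b) structure as `params`.\n",
    "  grads = grad(loss)(params, x, y)\n",
    "  return [(w - step_size * dw, b - step_size * db)\n",
    "          for (w, b), (dw, db) in zip(params, grads)]"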
   ]
  },
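  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `predict` returns log-probabilities and `targets` is one-hot, `loss` picks out the log-probability of the true class. It takes the mean over all `batch_size * n_targets` entries, so it equals the usual mean cross-entropy divided by the number of classes. That constant factor just rescales the gradient, and therefore effectively the step size. Here is a small numeric sketch using made-up logits and labels:\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "from jax.nn import log_softmax\n",
    "\n",
    "def one_hot(x, k, dtype=jnp.float32):\n",
    "  # Same definition as in the notebook\n",
    "  return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
    "\n",
    "labels = jnp.array([0, 2])\n",
    "targets = one_hot(labels, 3)  # [[1, 0, 0], [0, 0, 1]]\n",
    "# Made-up logits for a batch of 2 examples and 3 classes\n",
    "log_probs = log_softmax(jnp.array([[2.0, 0.5, 0.1], [0.3, 0.2, 1.5]]), axis=-1)\n",
    "\n",
    "loss_value = -jnp.mean(log_probs * targets)\n",
    "# Standard mean cross-entropy: negative log-probability of each true class\n",
    "cross_entropy = -jnp.mean(log_probs[jnp.arange(2), labels])\n",
    "print(loss_value, cross_entropy / 3)\n",
    "```"
   ]
  },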
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "umJJGZCC2oKl"
   },
   "source": [
    "## Data loading with PyTorch\n",
    "\n",
    "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader and add a tiny shim so that it yields NumPy arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "gEvWt8_u2pqG",
    "outputId": "2c83a679-9ce5-4c67-bccb-9ea835a8eaf6"
   },
   "outputs": [],
   "source": [
    "!pip install torch torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "cellView": "both",
    "id": "94PjXZ8y3dVF"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from jax.tree_util import tree_map\n",
    "from torch.utils.data import DataLoader, default_collate\n",
    "from torchvision.datasets import MNIST\n",
    "\n",
    "def numpy_collate(batch):\n",
    "  \"\"\"\n",
    "  Collate function specifies how to combine a list of data samples into a batch.\n",
    "  default_collate creates pytorch tensors, then tree_map converts them into numpy arrays.\n",
    "  \"\"\"\n",
    "  return tree_map(np.asarray, default_collate(batch))\n",
    "\n",
    "def flatten_and_cast(pic):\n",
    "  \"\"\"Convert PIL image to flat (1-dimensional) numpy array.\"\"\"\n",
    "  return np.ravel(np.array(pic, dtype=np.float32))"
   ]
  },
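  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the shim does, here is a sketch that runs it on a hypothetical hand-made batch rather than on real MNIST samples. `default_collate` stacks the images into one tensor and the labels into another. `tree_map(np.asarray, ...)` then converts each of those tensors to a NumPy array, which JAX functions accept directly.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from jax.tree_util import tree_map\n",
    "from torch.utils.data import default_collate\n",
    "\n",
    "def numpy_collate(batch):\n",
    "  # Same shim as in the notebook\n",
    "  return tree_map(np.asarray, default_collate(batch))\n",
    "\n",
    "# Hypothetical batch of two (flattened image, label) samples\n",
    "batch = [(np.zeros(4, dtype=np.float32), 3), (np.ones(4, dtype=np.float32), 7)]\n",
    "images, labels = numpy_collate(batch)\n",
    "print(type(images).__name__, images.shape, images.dtype)  # stacked images\n",
    "print(labels)  # stacked labels\n",
    "```"
   ]
  },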
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "l314jsfP4TN4"
   },
   "outputs": [],
   "source": [
    "# Define our dataset, using torch datasets\n",
    "mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=flatten_and_cast)\n",
    "# Create pytorch data loader with custom collate function\n",
    "training_generator = DataLoader(mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "FTNo4beUvb6t",
    "outputId": "65a9087c-c326-49e5-cbfc-e0839212fa31"
   },
   "outputs": [],
   "source": [
    "# Get the full train dataset (for checking accuracy while training)\n",
    "train_images = mnist_dataset.data.numpy().reshape(len(mnist_dataset.data), -1).astype(np.float32)\n",
    "train_labels = one_hot(mnist_dataset.targets.numpy(), n_targets)\n",
    "\n",
    "# Get the full test dataset\n",
    "mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)\n",
    "test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)\n",
    "test_labels = one_hot(mnist_dataset_test.targets.numpy(), n_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xxPd6Qw3Z98v"
   },
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "X2DnZo3iYj18",
    "outputId": "0eba3ca2-24a1-4cba-aaf4-3ac61d0c650e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 in 5.53 sec\n",
      "Training set accuracy 0.9156666994094849\n",
      "Test set accuracy 0.9199000000953674\n",
      "Epoch 1 in 1.13 sec\n",
      "Training set accuracy 0.9370499849319458\n",
      "Test set accuracy 0.9383999705314636\n",
      "Epoch 2 in 1.12 sec\n",
      "Training set accuracy 0.9490833282470703\n",
      "Test set accuracy 0.9467999935150146\n",
      "Epoch 3 in 1.21 sec\n",
      "Training set accuracy 0.9568833708763123\n",
      "Test set accuracy 0.9532999992370605\n",
      "Epoch 4 in 1.17 sec\n",
      "Training set accuracy 0.9631666541099548\n",
      "Test set accuracy 0.9574999809265137\n",
      "Epoch 5 in 1.17 sec\n",
      "Training set accuracy 0.9675000309944153\n",
      "Test set accuracy 0.9615999460220337\n",
      "Epoch 6 in 1.11 sec\n",
      "Training set accuracy 0.9709500074386597\n",
      "Test set accuracy 0.9652999639511108\n",
      "Epoch 7 in 1.17 sec\n",
      "Training set accuracy 0.9736999869346619\n",
      "Test set accuracy 0.967199981212616\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "  start_time = time.time()\n",
    "  for x, y in training_generator:\n",
    "    y = one_hot(y, n_targets)\n",
    "    params = update(params, x, y)\n",
    "  epoch_time = time.time() - start_time\n",
    "\n",
    "  train_acc = accuracy(params, train_images, train_labels)\n",
    "  test_acc = accuracy(params, test_images, test_labels)\n",
    "  print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n",
    "  print(\"Training set accuracy {}\".format(train_acc))\n",
    "  print(\"Test set accuracy {}\".format(test_acc))"
   ]
  },
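  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first epoch in the output above is noticeably slower than the rest. The most likely reason is that it includes compiling `update`, because `jit` compiles a function the first time it sees a given input shape and dtype. MNIST's training split has 60,000 images. With `batch_size = 128`, the loader yields 468 full batches plus a final batch of 96, since it keeps the incomplete last batch by default. `update` is therefore traced once for each of these two shapes. The sketch below shows this behavior with a hypothetical `mean_sq` function. Its Python-side log is appended to only while `jit` is tracing.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "from jax import jit\n",
    "\n",
    "trace_log = []  # filled only while jit traces the function\n",
    "\n",
    "@jit\n",
    "def mean_sq(x):\n",
    "  # Python side effect: runs at trace time, not on every call\n",
    "  trace_log.append(x.shape)\n",
    "  return jnp.mean(x ** 2)\n",
    "\n",
    "mean_sq(jnp.ones((128, 784)))\n",
    "mean_sq(jnp.ones((128, 784)))  # same shape and dtype: reuses the compiled version\n",
    "mean_sq(jnp.ones((96, 784)))   # new shape: traced and compiled again\n",
    "print(trace_log)\n",
    "\n",
    "num_full, last = divmod(60000, 128)\n",
    "print(num_full, last)\n",
    "```"
   ]
  },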
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xC1CMcVNYwxm"
   },
   "source": [
    "We've now used JAX's core transformations: `grad` for derivatives, `jit` for speedups, and `vmap` for auto-vectorization.\n",
    "We used JAX's NumPy API to specify all of our computation, borrowed the great data loaders from PyTorch, and ran the whole thing on the GPU."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Neural Network and Data Loading.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
